{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c408367e",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License.\n",
    "\n",
    "# 2D Model Inference on a 3D Volume  "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a8681db2",
   "metadata": {},
   "source": [
    "Usecase: A 2D Model, such as, a 2D segmentation U-Net operates on 2D input which can be slices from a 3D volume (for example, a CT scan). \n",
    "\n",
    "After editing sliding window inferer as described in this tutorial, it can handle the entire flow as shown:\n",
    "![image](../figures/2d_inference_3d_input.png)\n",
    "\n",
    "The input is a *3D Volume*, a *2D model* and the output is a *3D volume* with 2D slice predictions aggregated.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/2d_inference_3d_volume.ipynb)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5c3a44a5",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2e1b91f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a81870a3",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9cd1b08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from monai.config import print_config\n",
    "from monai.inferers import SliceInferer\n",
    "from monai.networks.nets import UNet\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85f00a47",
   "metadata": {},
   "source": [
    "## SliceInferer\n",
    "The simplest way to achieve this functionality is to extend the `SlidingWindowInferer` in `monai.inferers`. This is made available as `SliceInferer` in MONAI (https://docs.monai.io/en/latest/inferers.html#sliceinferer)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb0a63dd",
   "metadata": {},
   "source": [
    "## Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85b15305",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:00<00:00, 107.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Axial Inferer Output Shape:  torch.Size([1, 1, 64, 256, 256])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:01<00:00, 177.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coronal Inferer Output Shape:  torch.Size([1, 1, 64, 256, 256])\n"
     ]
    }
   ],
   "source": [
    "# Create a 2D UNet with randomly initialized weights for testing purposes\n",
    "\n",
    "# 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units\n",
    "net = UNet(\n",
    "    spatial_dims=2,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    channels=(4, 8, 16),\n",
    "    strides=(2, 2),\n",
    "    num_res_units=2,\n",
    ")\n",
    "\n",
    "# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n",
    "input_volume = torch.ones(1, 1, 64, 256, 256)\n",
    "\n",
    "# Create an instance of SliceInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n",
    "axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1, progress=True)\n",
    "\n",
    "output = axial_inferer(input_volume, net)\n",
    "\n",
    "# Output is a 3D volume with 2D slices aggregated\n",
    "print(\"Axial Inferer Output Shape: \", output.shape)\n",
    "# Create an instance of SliceInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n",
    "coronal_inferer = SliceInferer(\n",
    "    roi_size=(64, 256),\n",
    "    sw_batch_size=1,\n",
    "    spatial_dim=1,  # Spatial dim to slice along is added here\n",
    "    cval=-1,\n",
    "    progress=True,\n",
    ")\n",
    "\n",
    "output = coronal_inferer(input_volume, net)\n",
    "\n",
    "# Output is a 3D volume with 2D slices aggregated\n",
    "print(\"Coronal Inferer Output Shape: \", output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2596d86",
   "metadata": {},
   "source": [
    "Note that with `axial_inferer` and `coronal_inferer`, the number of inference iterations is 64 and 256 respectively."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
